# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting models to flatbuffer with the Vulkan delegate

# pyre-unsafe

import argparse
import logging
import os

import executorch.backends.vulkan.test.utils as test_utils
import torch
import torchvision
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.devtools import BundledProgram
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.devtools.bundled_program.serialize import (
    serialize_from_bundled_program_to_flatbuffer,
)
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.examples.models.model_factory import EagerModelFactory
from executorch.exir import to_edge_transform_and_lower
from executorch.extension.export_util.utils import save_pte_program
from executorch.extension.pytree import tree_flatten
from torch.export import Dim, export

# Log record layout: level, timestamp and source location, then the message.
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
# Configure the root logger once at import time so every module logs at INFO.
logging.basicConfig(format=FORMAT, level=logging.INFO)

import urllib


def is_vision_model(model_name):
    """Return True when ``model_name`` is one of the known vision models.

    Args:
        model_name: Name of the model to look up.

    Returns:
        ``True`` if the name appears in the list below, ``False`` otherwise.
    """
    # Models that are also registered in examples/models.
    registered_models = (
        "dl3",
        "edsr",
        "mv2",
        "mv3",
        "vit",
        "ic3",
        "ic4",
        "resnet18",
        "resnet50",
    )
    # Models that are not registered in examples/models but are available via
    # torchvision.
    torchvision_only_models = (
        "convnext_small",
        "densenet161",
        "shufflenet_v2_x1_0",
    )
    return model_name in registered_models + torchvision_only_models


def get_vision_model_sample_input():
    """Build a random sample input for the vision models.

    Returns:
        A one-element tuple holding a ``torch.randn`` tensor of shape
        ``(1, 3, 224, 224)``.
    """
    sample_shape = (1, 3, 224, 224)
    sample = torch.randn(*sample_shape)
    return (sample,)


def get_vision_model_dynamic_shapes():
    return (
        {
            2: Dim("height", min=1, max=16) * 16,
            3: Dim("width", min=1, max=16) * 16,
        },
    )


def get_dog_image_tensor(image_size=224, normalization="imagenet"):
    """Download the PyTorch Hub sample dog image and turn it into a model input.

    The image is written to ``dog.jpg`` in the current working directory, then
    converted to RGB, resized, and converted to a tensor.

    Args:
        image_size: Side length in pixels; the image is resized to
            ``(image_size, image_size)``.
        normalization: When equal to ``"imagenet"``, the mean/std normalization
            below is applied. Any other value skips normalization.

    Returns:
        A one-element tuple containing the image tensor with a leading batch
        dimension of size 1 added.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or ``PIL.Image.open`` raise when
        the download or decoding fails.
    """
    # Imported locally because a bare ``import urllib`` does not guarantee that
    # the ``urllib.request`` submodule has been loaded.
    import urllib.request

    from PIL import Image
    from torchvision import transforms

    url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
    filename = "dog.jpg"

    # ``urllib.URLopener`` only existed in Python 2, so the former
    # try/bare-except always ended up here; call urlretrieve directly instead
    # of masking unrelated errors (including KeyboardInterrupt).
    urllib.request.urlretrieve(url, filename)

    # Use a context manager so the underlying file handle is closed as soon as
    # the pixel data has been converted.
    with Image.open(filename) as raw_image:
        input_image = raw_image.convert("RGB")

    transforms_list = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
    if normalization == "imagenet":
        transforms_list.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        )

    preprocess = transforms.Compose(transforms_list)

    input_tensor = preprocess(input_image)
    # Add the batch dimension and wrap in a tuple of positional model inputs.
    return (input_tensor.unsqueeze(0),)


def init_model(model_name):
    """Instantiate a model that is not registered in ``examples/models``.

    Args:
        model_name: Name of the model to create.

    Returns:
        The constructed model, or ``None`` when ``model_name`` is not one of
        the names handled here.

    Raises:
        ImportError: If ``YOLO_NAS_S`` is requested but ``super_gradients``
            cannot be imported. The original import failure is chained as the
            cause.
    """
    # These names match the torchvision constructor names one-to-one; each
    # constructor is called without arguments.
    if model_name in ("convnext_small", "densenet161", "shufflenet_v2_x1_0"):
        return getattr(torchvision.models, model_name)()

    if model_name == "YOLO_NAS_S":
        try:
            from super_gradients.common.object_names import Models
            from super_gradients.training import models
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Please install super-gradients to use the YOLO_NAS_S model."
            ) from err

        return models.get(Models.YOLO_NAS_S, pretrained_weights="coco")

    return None


def get_sample_inputs(model_name):
    """Return sample inputs for ``model_name``, or ``None`` if none are known.

    Vision models get a random tensor input; ``YOLO_NAS_S`` gets the sample dog
    image resized to 640x640.

    Args:
        model_name: Name of the model to produce inputs for.

    Returns:
        A tuple of input tensors, or ``None`` for any other model name.
    """
    # Lock the random seed for reproducibility
    torch.manual_seed(42)

    sample_inputs = None
    if is_vision_model(model_name):
        sample_inputs = get_vision_model_sample_input()
    elif model_name == "YOLO_NAS_S":
        sample_inputs = get_dog_image_tensor(640)

    return sample_inputs


def get_dynamic_shapes(model_name):
    """Return the dynamic shape spec for ``model_name``, if one is defined.

    Args:
        model_name: Name of the model to look up.

    Returns:
        The vision model dynamic shapes when ``model_name`` is a vision model,
        otherwise ``None``.
    """
    if not is_vision_model(model_name):
        return None
    return get_vision_model_dynamic_shapes()


def main() -> None:  # noqa: C901
    """Export a model with the Vulkan delegate, driven by command line flags.

    Steps, in order:
      1. Parse arguments and build the model plus its example inputs.
      2. Export with ``torch.export`` and lower with ``VulkanPartitioner``.
      3. Save the ``.pte`` program, and optionally an ETRecord, the raw inputs
         (``--save_inputs``) and a bundled ``.bpte`` program (``--bundled``).
      4. Optionally run the lowered program against the eager model (``--test``).

    Raises:
        RuntimeError: If the model name is not recognized, or if ``--test`` is
            given and the outputs do not match the reference model.
    """
    logger = logging.getLogger("")
    logger.setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )

    parser.add_argument(
        "-fp16",
        "--force_fp16",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Force fp32 tensors to be converted to fp16 internally. Input/s outputs "
        "will be converted to/from fp32 when entering/exiting the delegate. Default is "
        "False",
    )

    parser.add_argument(
        "--small_texture_limits",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="sets the default texture limit to be (2048, 2048, 2048) which is "
        "compatible with more devices (i.e. desktop/laptop GPUs) compared to the "
        "default (16384, 16384, 2048) which is more targeted for mobile GPUs. Default "
        "is False.",
    )

    parser.add_argument(
        "--skip_memory_planning",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Skips memory planning pass while lowering, which can be used for "
        "debugging. Default is False.",
    )

    parser.add_argument(
        "-s",
        "--strict",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="whether to export with strict mode. Default is True",
    )

    parser.add_argument(
        "-d",
        "--dynamic",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Enable dynamic shape support. Default is False",
    )

    parser.add_argument(
        "-r",
        "--etrecord",
        required=False,
        default="",
        help="Generate and save an ETRecord to the given file location",
    )

    parser.add_argument("-o", "--output_dir", default=".", help="output directory")

    parser.add_argument(
        "-b",
        "--bundled",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Export as bundled program (.bpte) instead of regular program (.pte). Default is False",
    )

    parser.add_argument(
        "-t",
        "--test",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Execute lower_module_and_test_output to validate the model. Default is False",
    )

    parser.add_argument(
        "--save_inputs",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to save the inputs to the model. Default is False",
    )

    args = parser.parse_args()

    if args.model_name in MODEL_NAME_TO_MODEL:
        # Registered models come with their own example inputs and shapes.
        model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
            *MODEL_NAME_TO_MODEL[args.model_name]
        )
    else:
        model = init_model(args.model_name)
        # Fail immediately on an unknown name, before building any inputs.
        if model is None:
            raise RuntimeError(
                f"Model {args.model_name} is not a valid name. "
                f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
            )

        example_inputs = get_sample_inputs(args.model_name)
        dynamic_shapes = get_dynamic_shapes(args.model_name) if args.dynamic else None

    # Prepare model
    model.eval()

    # Setup compile options
    compile_options = {}
    if args.dynamic:
        compile_options["require_dynamic_shapes"] = True
        # Try to manually get the dynamic shapes for the model if not set
        if dynamic_shapes is None:
            dynamic_shapes = get_dynamic_shapes(args.model_name)

    if args.force_fp16:
        compile_options["force_fp16"] = True
    if args.skip_memory_planning:
        compile_options["skip_memory_planning"] = True
    if args.small_texture_limits:
        compile_options["small_texture_limits"] = True

    logging.info(f"Exporting model {args.model_name} with Vulkan delegate")

    # Export the model using torch.export
    if dynamic_shapes is not None:
        program = export(
            model, example_inputs, dynamic_shapes=dynamic_shapes, strict=args.strict
        )
    else:
        program = export(model, example_inputs, strict=args.strict)

    # Transform and lower with Vulkan partitioner
    edge_program = to_edge_transform_and_lower(
        program,
        partitioner=[VulkanPartitioner(compile_options)],
        generate_etrecord=args.etrecord,
    )

    logging.info(
        f"Exported and lowered graph:\n{edge_program.exported_program().graph}"
    )

    # Create executorch program
    exec_prog = edge_program.to_executorch()

    # Save ETRecord if requested
    if args.etrecord:
        exec_prog.get_etrecord().save(args.etrecord)
        logging.info(f"Saved ETRecord to {args.etrecord}")

    # Save the program
    output_filename = f"{args.model_name}_vulkan"

    atol = 1e-4
    rtol = 1e-4

    # If forcing fp16, then numerical divergence is expected
    if args.force_fp16:
        atol = 2e-2
        rtol = 1e-1

    # Save regular program
    save_pte_program(exec_prog, output_filename, args.output_dir)
    logging.info(
        f"Model exported and saved as {output_filename}.pte in {args.output_dir}"
    )

    if args.save_inputs:
        # Dump every flattened input tensor as raw bytes, one file per input.
        inputs_flattened, _ = tree_flatten(example_inputs)
        for i, input_tensor in enumerate(inputs_flattened):
            input_filename = os.path.join(args.output_dir, f"input{i}.bin")
            input_tensor.numpy().tofile(input_filename)
            # This message used to be a bare f-string expression and was
            # silently discarded; actually log it.
            logging.info(
                f"Model input saved as {input_filename} in {args.output_dir}"
            )

    if args.bundled:
        # Create bundled program
        logging.info("Creating bundled program with test cases")

        # Generate expected outputs by running the model
        expected_outputs = [model(*example_inputs)]

        # Flatten sample inputs to match expected format
        inputs_flattened, _ = tree_flatten(example_inputs)

        # Create test suite with the sample inputs and expected outputs
        test_suites = [
            MethodTestSuite(
                method_name="forward",
                test_cases=[
                    MethodTestCase(
                        inputs=inputs_flattened,
                        expected_outputs=expected_outputs,
                    )
                ],
            )
        ]

        # Create bundled program
        bp = BundledProgram(exec_prog, test_suites)

        # Serialize to flatbuffer
        bp_buffer = serialize_from_bundled_program_to_flatbuffer(bp)

        # Save bundled program; build the path the same way as the input files.
        bundled_output_path = os.path.join(args.output_dir, f"{output_filename}.bpte")
        with open(bundled_output_path, "wb") as file:
            file.write(bp_buffer)

        logging.info(
            f"Bundled program exported and saved as {output_filename}.bpte in {args.output_dir}"
        )

    # Test the model if --test flag is provided
    if args.test:
        test_result = test_utils.run_and_check_output(
            reference_model=model,
            executorch_program=exec_prog,
            sample_inputs=example_inputs,
            atol=atol,
            rtol=rtol,
        )

        if test_result:
            logging.info(
                "✓ Model test PASSED - outputs match reference within tolerance"
            )
        else:
            logging.error("✗ Model test FAILED - outputs do not match reference")
            raise RuntimeError(
                "Model validation failed: ExecuTorch outputs do not match reference model outputs"
            )


# Script entry point: run the export flow only when executed directly.
if __name__ == "__main__":
    # Run everything under torch.no_grad() so gradient tracking is disabled
    # for the whole export and validation flow.
    with torch.no_grad():
        main()  # pragma: no cover
